from collections.abc import Callable, Generator, Sequence  # noqa: I001
from contextlib import contextmanager
from typing import Any, Literal, TypeAlias

from numpy.random import Generator as NumpyGenerator
from numpy.typing import ArrayLike, NDArray
from torch import Generator as TorchGenerator
from torch import Tensor

from optiland._types import BEArrayT, ScalarOrArrayT
from optiland.backend.torch_backend import GradMode
from optiland.backend import linalg  # noqa: F401

# Explicit export list. Only the re-exported ``linalg`` submodule is named.
# NOTE(review): none of the functions declared below are listed here; type
# checkers use ``__all__`` as the star-import surface, so confirm that this
# omission is intentional.
__all__ = [
    "linalg",
]

# Backend-agnostic array type: either a NumPy array or a torch Tensor.
ndarray: TypeAlias = NDArray | Tensor  # noqa: PYI042
# Float constant; this stub declares only its type, not its value.
inf: float

# Backend-independent helpers.
# ``isinf`` / ``isnan`` are annotated to return the same ``ScalarOrArrayT``
# they receive (presumably a TypeVar -- confirm in ``optiland._types``).
def array_equal(a: BEArrayT, b: BEArrayT) -> bool: ...
def isinf(x: ScalarOrArrayT) -> ScalarOrArrayT: ...
def isnan(x: ScalarOrArrayT) -> ScalarOrArrayT: ...

# Backend selection. Backends are identified by plain strings; the set of
# valid names is not constrained at the type level (see
# ``list_available_backends`` for the names reported at runtime).
def set_backend(name: str) -> None: ...
def get_backend() -> str: ...
def list_available_backends() -> list[str]: ...

# Functions for the torch backend
# Device and precision are restricted to the literal strings shown in the
# annotations; each getter returns the same literal set its setter accepts.
def set_device(device: Literal["cpu", "cuda"]) -> None: ...
def get_device() -> Literal["cpu", "cuda"]: ...
def set_precision(precision: Literal["float32", "float64"]) -> None: ...
def get_precision() -> Literal["float32", "float64"]: ...

# Module-level ``GradMode`` object imported from the torch backend.
# NOTE(review): presumably controls gradient tracking -- confirm against
# ``optiland.backend.torch_backend.GradMode``.
grad_mode: GradMode

# Functions that are implemented by each backend
#
# Array creation and conversion: these accept plain Python / array-like input
# and are annotated to return the backend array type (``ndarray`` alias
# declared earlier in this stub: NumPy array or torch Tensor).
def array(x: ArrayLike) -> ndarray: ...
def zeros(shape: Sequence[int]) -> ndarray: ...
def ones(shape: Sequence[int]) -> ndarray: ...
def full(shape: Sequence[int], fill_value: float) -> ndarray: ...
def linspace(start: float, stop: float, num: int) -> ndarray: ...
# ``step`` is keyword-only here because it follows ``*args``.
def arange(*args: float | ndarray, step: float | ndarray = 1) -> ndarray: ...
def zeros_like(x: ArrayLike) -> ndarray: ...
def ones_like(x: ArrayLike) -> ndarray: ...
def load(filename: str) -> ndarray: ...
def cast(x: ArrayLike) -> ndarray: ...
# Functions typed with ``BEArrayT`` return the same ``BEArrayT`` they accept
# (presumably a TypeVar over the backend array types -- confirm in
# ``optiland._types``).
def copy(x: BEArrayT) -> BEArrayT: ...
def is_array_like(x: Any) -> bool: ...
def isfinite(x: ScalarOrArrayT) -> ScalarOrArrayT: ...
# Declared as a zero-argument function returning ``None``.
# NOTE(review): presumably an indexing sentinel analogous to
# ``numpy.newaxis`` -- confirm against the backend implementations.
def newaxis() -> None: ...

# Shape inspection and manipulation.
def shape(x: ArrayLike) -> tuple[int, ...]: ...
def size(x: ArrayLike) -> int: ...
def reshape(x: BEArrayT, shape: Sequence[int]) -> BEArrayT: ...
def stack(xs: Sequence[ArrayLike], axis: int = 0) -> ndarray: ...
def broadcast_to(x: BEArrayT, shape: Sequence[int]) -> BEArrayT: ...
def repeat(x: BEArrayT, repeats: int) -> BEArrayT: ...
def flip(x: BEArrayT) -> BEArrayT: ...
def meshgrid(*arrays: BEArrayT) -> tuple[BEArrayT, ...]: ...
# ``axis`` defaults to an empty tuple rather than ``None``.
def roll(
    x: BEArrayT, shift: int | Sequence[int], axis: int | tuple[int, ...] = ()
) -> BEArrayT: ...
def unsqueeze_last(x: BEArrayT) -> BEArrayT: ...
def tile(x: BEArrayT, dims: int | Sequence[int]) -> BEArrayT: ...
def isscalar(x: ArrayLike | ndarray) -> bool: ...

# Random-number generator handle: a NumPy ``Generator`` or a torch
# ``Generator``, matching the two array types in ``ndarray``.
_Generator: TypeAlias = NumpyGenerator | TorchGenerator

# Create a generator; ``seed`` is optional and defaults to ``None``.
def default_rng(seed: int | None = None) -> _Generator: ...
# Draw uniformly distributed samples between ``low`` and ``high``.
#
# Parameters:
#   low:       lower bound of the sampling interval (default 0).
#   high:      upper bound of the sampling interval (default 1). The stub
#              previously advertised ``high = 0``, i.e. the same value as
#              ``low``, which describes an empty interval; the default now
#              mirrors ``numpy.random.Generator.uniform`` (low=0.0, high=1.0).
#              NOTE(review): confirm both runtime backends default to 1.
#   size:      number of samples, or ``None`` for the backend default.
#              NOTE(review): ``random_normal`` declares ``Sequence[int]`` for
#              the same parameter -- confirm whether shapes are accepted here.
#   generator: optional NumPy/torch generator to draw from; ``None`` leaves
#              the choice of generator to the backend.
# Returns: a backend array (``ndarray`` alias) of samples.
def random_uniform(
    low: float = 0,
    high: float = 1,
    size: int | None = None,
    generator: _Generator | None = None,
) -> ndarray: ...
# Draw normally distributed samples.
#   loc / scale: declared defaults are 0 and 1 respectively.
#   size:        output shape as a sequence of ints, or ``None``.
#   generator:   optional NumPy/torch generator; ``None`` leaves the choice
#                of generator to the backend.
# Returns a backend array (``ndarray`` alias).
def random_normal(
    loc: float = 0,
    scale: float = 1,
    size: Sequence[int] | None = None,
    generator: _Generator | None = None,
) -> ndarray: ...

# Mathematical operations
#
# Unary functions: accept array-like input and are annotated to return a
# backend array.
def sqrt(x: ArrayLike) -> ndarray: ...
def sin(x: ArrayLike) -> ndarray: ...
def cos(x: ArrayLike) -> ndarray: ...
def tan(x: ArrayLike) -> ndarray: ...
def exp(x: ArrayLike) -> ndarray: ...
def log2(x: ArrayLike) -> ndarray: ...
def abs(x: ArrayLike) -> ndarray: ...
def radians(x: ArrayLike) -> ndarray: ...
def degrees(x: ArrayLike) -> ndarray: ...
def deg2rad(x: ArrayLike) -> ndarray: ...
def rad2deg(x: ArrayLike) -> ndarray: ...

# Reductions. ``max`` / ``min`` take no ``axis`` and are annotated as
# returning a Python scalar.
def max(x: ArrayLike) -> int | float: ...
def min(x: ArrayLike) -> int | float: ...
def maximum(a: ArrayLike, b: ArrayLike) -> ndarray: ...
def nanmax(
    x: ArrayLike, axis: int | None = None, keepdims: bool = False
) -> ndarray: ...
def mean(x: ArrayLike, axis: int | None = None, keepdims: bool = False) -> ndarray: ...
# NOTE(review): ``std`` and ``sum`` accept an ``axis`` but are annotated as
# returning ``float``, unlike ``mean`` / ``nanmax`` which return ``ndarray``
# -- confirm what the backends return when ``axis`` is given.
def std(x: ArrayLike, axis: int | tuple[int, ...] | None = None) -> float: ...
def sum(x: ArrayLike, axis: int | tuple[int, ...] | None = None) -> float: ...
def all(x: bool | ArrayLike) -> bool: ...
def any(x: bool | ArrayLike) -> bool: ...
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> ndarray: ...
def factorial(n: ArrayLike) -> ndarray: ...

# Histogram helpers. ``histogram2d`` returns a 3-tuple and
# ``get_bilinear_weights`` a 2-tuple of backend arrays; the meaning of each
# element is not expressed in the types -- see the backend implementations.
def histogram2d(
    x: BEArrayT, y: BEArrayT, bins: Sequence[BEArrayT]
) -> tuple[BEArrayT, BEArrayT, BEArrayT]: ...
def get_bilinear_weights(
    coords: BEArrayT, bin_edges: Sequence[BEArrayT]
) -> tuple[BEArrayT, BEArrayT]: ...
# Returns ``None``; presumably writes ``source`` into ``destination`` in
# place -- confirm against the backends.
def copy_to(source: BEArrayT, destination: BEArrayT) -> None: ...

# Linear algebra
# All functions here take and return the same ``BEArrayT``; operand shapes
# are not expressed in the types.
def matmul(a: BEArrayT, b: BEArrayT) -> BEArrayT: ...
# Takes three operands; presumably computes a chained product a @ b @ c over
# batched matrices -- confirm against the backends.
def batched_chain_matmul3d(a: BEArrayT, b: BEArrayT, c: BEArrayT) -> BEArrayT: ...
def cross(a: BEArrayT, b: BEArrayT) -> BEArrayT: ...
def matrix_vector_multiply_and_squeeze(p: BEArrayT, E: BEArrayT) -> BEArrayT: ...
def to_complex(x: BEArrayT) -> BEArrayT: ...
# NOTE(review): shares its (p, E) signature with
# ``matrix_vector_multiply_and_squeeze`` -- confirm how the two differ.
def mult_p_E(p: BEArrayT, E: BEArrayT) -> BEArrayT: ...

# Interpolation
# ``interp`` uses the (x, xp, fp) parameter names of ``numpy.interp``;
# presumably 1-D interpolation of ``fp`` sampled at ``xp`` -- confirm.
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike) -> ndarray: ...
# Takes sample ``points`` / ``values`` plus two query arrays ``Hx`` / ``Hy``;
# array shapes are not expressed in the types.
def nearest_nd_interpolator(
    points: BEArrayT, values: BEArrayT, Hx: BEArrayT, Hy: BEArrayT
) -> BEArrayT: ...

# Polynomial operations
# ``polyfit`` returns a backend array (presumably the fitted coefficients).
# NOTE(review): coefficient ordering is not expressed in the types --
# confirm against the backends.
def polyfit(x: BEArrayT, y: BEArrayT, degree: int) -> BEArrayT: ...
# The return annotation is the same ``float | ndarray`` union as the ``x``
# parameter; it is not tied to the argument's actual type.
def polyval(coeffs: Sequence[float], x: float | ndarray) -> float | ndarray: ...

# Padding
# Only ``mode="constant"`` is permitted by the annotation; ``pad_width`` is
# a single int, and ``constant_values`` defaults to 0 but may also be
# ``None``.
def pad(
    x: BEArrayT,
    pad_width: int,
    mode: Literal["constant"] = "constant",
    constant_values: float | None = 0,
) -> BEArrayT: ...

# Vectorization
# Wraps an arbitrary callable and returns a callable typed as taking one
# backend array and returning one backend array.
def vectorize(pyfunc: Callable[..., Any]) -> Callable[[ndarray], ndarray]: ...

# Conversion and utilities
# Each accepts array-like input (or an int for ``eye``) and is annotated to
# return a backend array.
def atleast_1d(x: ArrayLike) -> ndarray: ...
def atleast_2d(x: ArrayLike) -> ndarray: ...
def as_array_1d(data: ArrayLike) -> ndarray: ...
# Takes a single size ``n``; presumably the n-by-n identity -- confirm.
def eye(n: int) -> ndarray: ...

# Error State Context
# Context manager accepting arbitrary keyword options. The yielded value is
# typed as ``Any``.
# NOTE(review): the name matches ``numpy.errstate``; confirm which keywords
# each backend honours.
@contextmanager
def errstate(**kwargs: Any) -> Generator[Any, Any, Any]: ...

# Miscellaneous Utilities
# Takes polygon/path ``vertices`` and query ``points`` and returns a backend
# array; presumably a per-point containment mask -- confirm.
def path_contains_points(vertices: BEArrayT, points: BEArrayT) -> BEArrayT: ...
# Always returns a NumPy array (``NDArray``), regardless of the input's
# array type.
def to_numpy(obj: ArrayLike) -> NDArray: ...
